{
  "cells": [
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "%matplotlib inline"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "\n# Compare stochastic learning strategies for MLPClassifier\n\n\nThis example visualizes some training loss curves for different stochastic\nlearning strategies, including SGD and Adam. Because of time constraints, we\nuse several small datasets, for which L-BFGS might be more suitable. The\ngeneral trend shown in these examples seems to carry over to larger datasets,\nhowever.\n\nNote that these results can be highly dependent on the value of the initial\nlearning rate, ``learning_rate_init``.\n\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "print(__doc__)\n",
        "\n",
        "import warnings\n",
        "\n",
        "import matplotlib.pyplot as plt\n",
        "\n",
        "from sklearn.neural_network import MLPClassifier\n",
        "from sklearn.preprocessing import MinMaxScaler\n",
        "from sklearn import datasets\n",
        "from sklearn.exceptions import ConvergenceWarning\n",
        "\n",
        "# different learning rate schedules and momentum parameters; entry i is\n",
        "# named by labels[i] and drawn with plot_args[i]. The two inv-scaling\n",
        "# momentum entries follow the same order as the constant ones (plain\n",
        "# momentum, then Nesterov's) so each one matches its label and color.\n",
        "params = [{'solver': 'sgd', 'learning_rate': 'constant', 'momentum': 0,\n",
        "           'learning_rate_init': 0.2},\n",
        "          {'solver': 'sgd', 'learning_rate': 'constant', 'momentum': .9,\n",
        "           'nesterovs_momentum': False, 'learning_rate_init': 0.2},\n",
        "          {'solver': 'sgd', 'learning_rate': 'constant', 'momentum': .9,\n",
        "           'nesterovs_momentum': True, 'learning_rate_init': 0.2},\n",
        "          {'solver': 'sgd', 'learning_rate': 'invscaling', 'momentum': 0,\n",
        "           'learning_rate_init': 0.2},\n",
        "          {'solver': 'sgd', 'learning_rate': 'invscaling', 'momentum': .9,\n",
        "           'nesterovs_momentum': False, 'learning_rate_init': 0.2},\n",
        "          {'solver': 'sgd', 'learning_rate': 'invscaling', 'momentum': .9,\n",
        "           'nesterovs_momentum': True, 'learning_rate_init': 0.2},\n",
        "          {'solver': 'adam', 'learning_rate_init': 0.01}]\n",
        "\n",
        "labels = [\"constant learning-rate\", \"constant with momentum\",\n",
        "          \"constant with Nesterov's momentum\",\n",
        "          \"inv-scaling learning-rate\", \"inv-scaling with momentum\",\n",
        "          \"inv-scaling with Nesterov's momentum\", \"adam\"]\n",
        "\n",
        "plot_args = [{'c': 'red', 'linestyle': '-'},\n",
        "             {'c': 'green', 'linestyle': '-'},\n",
        "             {'c': 'blue', 'linestyle': '-'},\n",
        "             {'c': 'red', 'linestyle': '--'},\n",
        "             {'c': 'green', 'linestyle': '--'},\n",
        "             {'c': 'blue', 'linestyle': '--'},\n",
        "             {'c': 'black', 'linestyle': '-'}]\n",
        "\n",
        "# params, labels and plot_args are consumed together through zip(), which\n",
        "# silently truncates on a length mismatch; fail loudly instead.\n",
        "assert len(params) == len(labels) == len(plot_args)\n",
        "\n",
        "\n",
        "def plot_on_dataset(X, y, ax, name):\n",
        "    \"\"\"Fit an MLPClassifier per entry of ``params`` and plot the loss curves.\n",
        "\n",
        "    Parameters\n",
        "    ----------\n",
        "    X : array-like\n",
        "        Training features; rescaled with MinMaxScaler before fitting.\n",
        "    y : array-like\n",
        "        Training targets.\n",
        "    ax : matplotlib Axes\n",
        "        Axes on which one loss curve per strategy is drawn.\n",
        "    name : str\n",
        "        Dataset name, used as the axes title and to choose ``max_iter``.\n",
        "    \"\"\"\n",
        "    print(\"\\nlearning on dataset %s\" % name)\n",
        "    ax.set_title(name)\n",
        "\n",
        "    X = MinMaxScaler().fit_transform(X)\n",
        "    mlps = []\n",
        "    if name == \"digits\":\n",
        "        # digits is larger but converges fairly quickly\n",
        "        max_iter = 15\n",
        "    else:\n",
        "        max_iter = 400\n",
        "\n",
        "    for label, param in zip(labels, params):\n",
        "        print(\"training: %s\" % label)\n",
        "        mlp = MLPClassifier(random_state=0,\n",
        "                            max_iter=max_iter, **param)\n",
        "\n",
        "        # some parameter combinations will not converge as can be seen on the\n",
        "        # plots so they are ignored here\n",
        "        with warnings.catch_warnings():\n",
        "            warnings.filterwarnings(\"ignore\", category=ConvergenceWarning,\n",
        "                                    module=\"sklearn\")\n",
        "            mlp.fit(X, y)\n",
        "\n",
        "        mlps.append(mlp)\n",
        "        print(\"Training set score: %f\" % mlp.score(X, y))\n",
        "        print(\"Training set loss: %f\" % mlp.loss_)\n",
        "    for mlp, label, args in zip(mlps, labels, plot_args):\n",
        "        ax.plot(mlp.loss_curve_, label=label, **args)\n",
        "\n",
        "\n",
        "fig, axes = plt.subplots(2, 2, figsize=(15, 10))\n",
        "# load / generate some toy datasets\n",
        "iris = datasets.load_iris()\n",
        "X_digits, y_digits = datasets.load_digits(return_X_y=True)\n",
        "data_sets = [(iris.data, iris.target),\n",
        "             (X_digits, y_digits),\n",
        "             datasets.make_circles(noise=0.2, factor=0.5, random_state=1),\n",
        "             datasets.make_moons(noise=0.3, random_state=0)]\n",
        "\n",
        "for ax, data, name in zip(axes.ravel(), data_sets, ['iris', 'digits',\n",
        "                                                    'circles', 'moons']):\n",
        "    plot_on_dataset(*data, ax=ax, name=name)\n",
        "\n",
        "# every axes carries the same curves in the same order, so the lines of the\n",
        "# last axes from the loop serve as legend handles for the whole figure\n",
        "fig.legend(ax.get_lines(), labels, ncol=3, loc=\"upper center\")\n",
        "plt.show()"
      ]
    }
  ],
  "metadata": {
    "kernelspec": {
      "display_name": "Python 3",
      "language": "python",
      "name": "python3"
    },
    "language_info": {
      "codemirror_mode": {
        "name": "ipython",
        "version": 3
      },
      "file_extension": ".py",
      "mimetype": "text/x-python",
      "name": "python",
      "nbconvert_exporter": "python",
      "pygments_lexer": "ipython3",
      "version": "3.6.9"
    }
  },
  "nbformat": 4,
  "nbformat_minor": 0
}